{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/oughtinc/ergo/blob/notebooks-readme/notebooks/covid-19-metaculus.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install [Ergo](https://github.com/oughtinc/ergo) (our forecasting library) and a few tools we'll use in this colab:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet poetry  # Fixes https://github.com/python-poetry/poetry/issues/532\n",
    "!pip install --quiet git+https://github.com/oughtinc/ergo.git@6396f5ec4a73a18d36faa1651b1cd9ad852f916e\n",
    "!pip install --quiet pendulum seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext google.colab.data_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import ergo\n",
    "import pendulum\n",
    "import pandas\n",
    "import seaborn\n",
    "\n",
    "from ergo import logistic\n",
    "\n",
    "from types import SimpleNamespace\n",
    "from typing import List\n",
    "from pendulum import DateTime\n",
    "from matplotlib import pyplot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Questions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are Metaculus ids for the questions we'll load, and some short names that will allow us to associate questions with variables in our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question_ids = [3704, 3712, 3713, 3711, 3722, 3761, 3705, 3706]\n",
    "question_names = [\n",
    "  \"WHO Eastern Mediterranean Region on 2020/03/27\",\n",
    "  \"WHO Region of the Americas on 2020/03/27\",\n",
    "  \"WHO Western Pacific Region on 2020/03/27\",\n",
    "  \"WHO South-East Asia Region on 2020/03/27\",\n",
    "  \"South Korea on 2020/03/27\",\n",
    "  \"United Kingdom on 2020/03/27\",\n",
    "  \"WHO African Region on 2020/03/27\",\n",
    "  \"WHO European Region on 2020/03/27\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the question data from Metaculus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metaculus = ergo.Metaculus(username=\"ought\", password=\"\")\n",
    "questions = [metaculus.get_question(id, name=name) for id, name in zip(question_ids, question_names)]\n",
    "ergo.MetaculusQuestion.to_dataframe(questions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our most important data is the data about confirmed cases (from Hopkins):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "confirmed_infections = ergo.data.covid19.ConfirmedInfections()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assumptions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assumptions are things that should be inferred from data but currently aren't:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assumptions = SimpleNamespace()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll manually add some data about [doubling times](https://ourworldindata.org/coronavirus#the-growth-rate-of-covid-19-deaths) (in days):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assumptions.doubling_time = {\n",
    "  \"World\": 8,\n",
    "  \"China\": 33,\n",
    "  \"Italy\": 4,\n",
    "  \"Iran\": 5,\n",
    "  \"Spain\": 3,\n",
    "  \"France\": 4,\n",
    "  \"United States\": 3,\n",
    "  \"United Kingdom\": 3,\n",
    "  \"South Korea\": 12,\n",
    "  \"Netherlands\": 1,\n",
    "  \"Japan\": 8,\n",
    "  \"Switzerland\": 5,\n",
    "  \"Philippines\": 4,\n",
    "  \"Belgium\": 1,\n",
    "  \"San Marino\": 3,\n",
    "  \"Germany\": 5,\n",
    "  \"Iraq\": 8,\n",
    "  \"Sweden\": 3,\n",
    "  \"Canada\": 2,\n",
    "  \"Algeria\": 4,\n",
    "  \"Australia\": 4,\n",
    "  \"Egypt\": 2,\n",
    "  \"Greece\": 5,\n",
    "  \"Indonesia\": 7,\n",
    "  \"Poland\": 5\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To estimate doubling times for places where we don't have data we'll specify which places are similar to which other places:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assumptions.similar_areas = {\n",
    "  \"Bay Area\": [\"United States\", \"Italy\"],\n",
    "  \"San Francisco\": [\"United States\", \"Italy\"],\n",
    "  \"WHO Eastern Mediterranean Region\": [\"Iraq\", \"Iran\"],\n",
    "  \"WHO Region of the Americas\": [\"United States\"],\n",
    "  \"WHO Western Pacific Region\": [\"United States\", \"Italy\", \"Spain\", \"South Korea\"],\n",
    "  \"WHO South-East Asia Region\": [\"South Korea\", \"China\", \"Japan\", \"Philippines\", \"Indonesia\"],\n",
    "  \"South Korea\": [\"South Korea\"],\n",
    "  \"United Kingdom\": [\"United Kingdom\"],\n",
    "  \"WHO African Region\": [\"Algeria\"],\n",
    "  \"WHO European Region\": [\"Italy\", \"Spain\",\"France\",\"Germany\",\"Greece\"], # \"Belgium\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Main model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Area = str\n",
    "\n",
    "def get_doubling_time(area: Area):\n",
    "  similar_areas = assumptions.similar_areas[area]\n",
    "  doubling_times = [assumptions.doubling_time[proxy] for proxy in similar_areas]\n",
    "  proxy_doubling_time = ergo.random_choice(doubling_times)\n",
    "  doubling_time = ergo.lognormal_from_interval(proxy_doubling_time - 0.5, proxy_doubling_time + 0.5)\n",
    "  return doubling_time\n",
    "\n",
    "def model(start: DateTime, end: DateTime, areas: List[Area]):\n",
    "  for area in areas:\n",
    "    doubling_time = get_doubling_time(area)\n",
    "    confirmed = confirmed_infections(area, start)\n",
    "    for i in range((end - start).days):\n",
    "      date = start.add(days=i)\n",
    "      confirmed = confirmed * 2**(1 / doubling_time)\n",
    "      ergo.tag(confirmed, f\"{area} on {date.format('YYYY/MM/DD')}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model parameters\n",
    "start_date = pendulum.now(tz=\"US/Pacific\").subtract(days = 10)\n",
    "end_date = max(question.resolve_time for question in questions).add(days = 3)\n",
    "areas = [re.match(\"(.*)? on\", name).groups()[0] for name in question_names]\n",
    "\n",
    "# Get samples from model for all variables\n",
    "samples = ergo.run(lambda: model(start_date, end_date, areas), num_samples=5000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at raw samples from the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summary stats:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plot some marginals:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for question in questions:\n",
    "  pyplot.figure()\n",
    "  seaborn.distplot(samples[question.name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Submit predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert samples to Metaculus distributions and submit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for question in questions:\n",
    "  if question.name in samples:\n",
    "    params = question.submit_from_samples(samples[question.name])\n",
    "    print(f\"Submitted Logistic{params} for {question.name}\")\n",
    "  else:\n",
    "    print(f\"No predictions for {question.name}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
